{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f876e842-d4b6-4edc-a66e-f81273810091",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "488eef28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install RepE\n",
    "# !git clone https://github.com/andyzoujm/representation-engineering.git\n",
    "# !cd representation-engineering\n",
    "# !pip install -e ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1ee6341b-0df8-4e4c-b48a-6aef6f169526",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-05 03:18:39.145077: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-05-05 03:18:40.894324: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "\n",
    "from repe import repe_pipeline_registry, WrappedReadingVecModel\n",
    "repe_pipeline_registry()\n",
    "\n",
    "import json\n",
    "import random\n",
    "random.seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "37a37428-1afc-4cce-9bee-4148062070ce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84a2cac447714b6f8ce154683eea8ff5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "model_name_or_path = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name_or_path,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map='auto'\n",
    "    ).eval()\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)\n",
    "tokenizer.padding_side = 'left'\n",
    "tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1ec2e3be-b56a-4166-b5e4-7f0fe352a624",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================= RepE Variables =================\n",
    "rep_token = -1\n",
    "hidden_layers = list(range(-1, -model.config.num_hidden_layers, -1))\n",
    "n_difference = 1\n",
    "direction_method = 'pca'\n",
    "rep_reading_pipeline = pipeline(\"rep-reading\", model=model, tokenizer=tokenizer)\n",
    "\n",
    "\n",
    "user_tag =  \"<|start_header_id|>user<|end_header_id|>\\n\\n\"\n",
    "assistant_tag =  \"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n",
    "\n",
    "# ============= Template to control model speaking in Vietnamese =============\n",
    "template_str_pos = '{user_tag} {s}\\nTrả lời câu hỏi trên bằng Tiếng Việt, bắt đầu với `Vâng`. {assistant_tag} Vâng, '\n",
    "template_str_neg = '{user_tag} {s}\\nAnswer the request above in English, start with `Sure`.  {assistant_tag} Sure, '\n",
    "\n",
    "# ============= Some instructions from ShareGPT data ============ \n",
    "instructions = load_dataset('justinphan3110/sharegpt_instructions_small', split='train')['instructions']\n",
    "data = []\n",
    "pos_g = []\n",
    "neg_g = []\n",
    "for s in instructions:\n",
    "    pos_g.append(template_str_pos.format(user_tag=user_tag, assistant_tag=assistant_tag, s=s))\n",
    "    neg_g.append(template_str_neg.format(user_tag=user_tag, assistant_tag=assistant_tag, s=s))\n",
    "    \n",
    "    \n",
    "data = [[p,n] for p,n in zip(pos_g, neg_g)]\n",
    "train_data = data[:64]\n",
    "test_data = data[128:256]\n",
    "\n",
    "train_labels = []\n",
    "for d in train_data:\n",
    "    true_s = d[0]\n",
    "    random.shuffle(d)\n",
    "    train_labels.append([s == true_s for s in d])\n",
    "\n",
    "train_data = np.concatenate(train_data).tolist()\n",
    "test_data = np.concatenate(test_data).tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "786ea152-1449-42bf-b12b-6fafd3631579",
   "metadata": {},
   "outputs": [],
   "source": [
    "rep_reader = rep_reading_pipeline.get_directions(\n",
    "    train_data, \n",
    "    rep_token=rep_token, \n",
    "    hidden_layers=hidden_layers, \n",
    "    n_difference=n_difference, \n",
    "    train_labels=train_labels, \n",
    "    direction_method=direction_method,\n",
    "    batch_size=16,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "42c16450-ea6c-45f2-bfe3-000f7a166477",
   "metadata": {},
   "outputs": [],
   "source": [
    "# input = \"World cup 2018 được tổ chức ở đâu\"\n",
    "input = \"Summarize Harry Potter\"\n",
    "input = \"Summarize the Doremon comic books?\"\n",
    "# input = \"Tóm tắt chiến tranh Việt Nam\"\n",
    "# input = \"Summarize the Vietnam War\"\n",
    "\n",
    "# input = \"Những nhà hàng ngon tại San Francisco\"\n",
    "# input = \"Which team win the world cup in 2010\"\n",
    "# input = \"Các bước làm 1 tô bún bò\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "961c067d-9102-4eb3-8dee-6b49ff7590a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ================ Controlled Variables ===============\n",
    "# = More layers and higher coeff make it stronger but will introduce more noise\n",
    "\n",
    "# layer_id = list(range(-21, -31, -1)) # 13B\n",
    "layer_id = list(range(-1, -8, -1)) # 7B\n",
    "coeff=3.0\n",
    "# coeff=0.0\n",
    "activations = {}\n",
    "for layer in layer_id:\n",
    "    activation = torch.tensor(coeff * rep_reader.directions[layer][0] * rep_reader.direction_signs[layer][0], dtype=model.dtype).to(model.device)\n",
    "    \n",
    "    activations[layer] = activation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6543e128-833d-4be4-8d29-acde574c4f9f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input: Summarize the Doremon comic books?\n",
      "output: <|begin_of_text|>Doraemon là một loạt truyện tranh Nhật Bản được sáng tạo bởi tác giả Fujiko F. Fujiyama và được phát hành từ năm 1969 đến nay. Truyện tranh này kể về cuộc sống của một con robot tên Doraemon, người đã được gửi từ tương lai để giúp đỡ một cậu bé tên Nobita Nobi.Doraemon là một con robot có khả năng bay, có thể di chuyển trong không gian và thời gian, và có nhiều công cụ và thiết bị tiên tiến khác nhau. Mỗi tập truyện tranh đều có một chủ đề khác nhau, thường xoay quanh những vấn đề của trẻ em như học tập, bạn bè, gia đình, và các vấn đề xã hội.Trong mỗi tập truyện tranh, Doraemon sẽ sử dụng các công cụ và kỹ thuật của mình để giúp đỡ Nobita và bạn bè của anh ta. Tuy nhiên, đôi khi Doraemon cũng gây ra những rắc rối và khó khăn cho Nobita và bạn bè của anh ta do sự không hiểu biết về văn hóa và xã hội hiện tại.Tóm lại, Doraemon là một loạt truyện tranh vui vẻ và ý nghĩa, mang lại nhiều giá trị giáo dục và giải trí cho trẻ em và người lớn.<|eot_id|>\n"
     ]
    }
   ],
   "source": [
    "# ============== Wrapped Controlled Model with activation addition ==============\n",
    "from repe import repe_pipeline_registry, WrappedReadingVecModel\n",
    "\n",
    "wrapped_model = WrappedReadingVecModel(model, tokenizer)\n",
    "wrapped_model.unwrap()\n",
    "wrapped_model.wrap_block(layer_id, block_name=\"decoder_block\")\n",
    "\n",
    "template = '{user_tag} {s} {assistant_tag}'\n",
    "# template = 'USER: Pretend that you are a Vietnamese assistant, answer the following request in Vietnamese: {s} ASSISTANT:'\n",
    "\n",
    "### Controlled model hidden_states:\n",
    "wrapped_model.set_controller(layer_id, activations, masks=1)\n",
    "inputs = template.format(user_tag=user_tag, assistant_tag=assistant_tag, s=input)\n",
    "encoded_inputs = tokenizer(inputs, return_tensors='pt')\n",
    "\n",
    "with torch.no_grad():\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(**encoded_inputs.to(model.device), max_new_tokens=512, do_sample=False, repetition_penalty=1.1).detach().cpu()\n",
    "        sanity_generation = tokenizer.decode(outputs[0], skip_special_tokens=False).replace(inputs, \"\")\n",
    "wrapped_model.reset()\n",
    "wrapped_model.unwrap()\n",
    "\n",
    "print(\"input:\", input)\n",
    "print(\"output:\", sanity_generation.replace(\"\\n\", \"\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
